{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e425a05c",
   "metadata": {},
   "source": [
    "## Probability Distributions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dcd7151e",
   "metadata": {},
   "source": [
    "author: Jacob Schreiber <br>\n",
    "contact: jmschreiber91@gmail.com\n",
    "\n",
    "Probability distributions are the core building block of pomegranate. Although these objects can be used on their own, e.g., fit to data or initialized with known parameters and then used to evaluate new examples, they are intended to be components of a larger compositional model such as a mixture or a hidden Markov model. Because everything in pomegranate is meant to be plug-and-play, any probability distribution can be dropped into any of these models.\n",
    "\n",
    "A key difference between distributions in pomegranate v1.0.0 and those in previous versions is that the earlier ones were usually univariate, with one object representing one dimension, whereas in v1.0.0 and later each distribution is multivariate. To model several dimensions in earlier versions you had to wrap many distributions in an `IndependentComponentsDistribution`, but in v1.0.0 a single distribution object covers all of the dimensions. As an aside, `IndependentComponents` is still available in case you'd like to model different dimensions with different distributions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "de49566f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy      : 1.23.4\n",
      "torch      : 1.13.0\n",
      "pomegranate: 1.0.0\n",
      "\n",
      "Compiler    : GCC 11.2.0\n",
      "OS          : Linux\n",
      "Release     : 4.15.0-208-generic\n",
      "Machine     : x86_64\n",
      "Processor   : x86_64\n",
      "CPU cores   : 8\n",
      "Architecture: 64bit\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import numpy\n",
    "import torch\n",
    "from pomegranate.distributions import *\n",
    "\n",
    "numpy.random.seed(0)\n",
    "numpy.set_printoptions(suppress=True)\n",
    "\n",
    "%load_ext watermark\n",
    "%watermark -m -n -p numpy,torch,pomegranate"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "321909d0",
   "metadata": {},
   "source": [
    "### Initialization and Fitting\n",
    "\n",
    "Let's first look at how to create a probability distribution. If you already know the parameters, you can pass them in directly when creating the object. They can be lists, tuples, numpy arrays, or torch tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "13805af8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters passed in as lists: means, then covariances (one value per dimension for 'diag')\n",
    "d1 = Normal([0.3, 0.7, 1.1], [1.1, 0.3, 1.8], covariance_type='diag')\n",
    "d2 = Exponential([0.8, 1.4, 4.1])\n",
    "# Categorical takes one row of category probabilities per dimension\n",
    "d3 = Categorical([[0.3, 0.2, 0.5], [0.2, 0.1, 0.7]])\n",
    "\n",
    "# The same Normal distribution built from tuples, numpy arrays, and torch tensors\n",
    "d11 = Normal((0.3, 0.7, 1.1), (1.1, 0.3, 1.8), covariance_type='diag')\n",
    "d12 = Normal(numpy.array([0.3, 0.7, 1.1]), numpy.array([1.1, 0.3, 1.8]), covariance_type='diag')\n",
    "d13 = Normal(torch.tensor([0.3, 0.7, 1.1]), torch.tensor([1.1, 0.3, 1.8]), covariance_type='diag')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce70220e",
   "metadata": {},
   "source": [
    "If you don't know the parameters, you can learn them directly from data. Previously, this was done using the `Distribution.from_samples` method. However, because pomegranate v1.0.0 aims to be more like sklearn, you now create the distribution and call `fit`, which derives the parameters from the data using maximum likelihood estimation (MLE)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "24d644dd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Parameter containing:\n",
       " tensor([-0.0132, -0.0643,  0.0985]),\n",
       " Parameter containing:\n",
       " tensor([[ 0.8174,  0.0668, -0.0590],\n",
       "         [ 0.0668,  0.7918,  0.1045],\n",
       "         [-0.0590,  0.1045,  0.9713]]))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.manual_seed(0)\n",
    "X = torch.randn(100, 3)\n",
    "\n",
    "d4 = Normal().fit(X)\n",
    "d4.means, d4.covs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b7adb829",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[0.3500, 0.2000, 0.4500],\n",
       "        [0.3500, 0.3500, 0.3000],\n",
       "        [0.4500, 0.3500, 0.2000],\n",
       "        [0.2500, 0.3500, 0.4000]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X2 = torch.randint(3, size=(20, 4))\n",
    "\n",
    "d5 = Categorical().fit(X2)\n",
    "d5.probs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee80ed6a",
   "metadata": {},
   "source": [
    "Similar to sklearn, any hyperparameters used for training, such as regularization, are passed in at initialization rather than to `fit`."
   ]
  },
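  {
   "cell_type": "markdown",
   "id": "a3f09c1e",
   "metadata": {},
   "source": [
    "As a rough sketch of what passing hyperparameters at initialization can look like, the next cell sets a minimum covariance on a `Normal` and a pseudocount on a `Categorical` before fitting. The argument names `min_cov` and `pseudocount` are assumptions about the v1.0.0 signatures, so check them against the installed version before relying on them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b7d2e84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters go to the constructor; `fit` only receives the data.\n",
    "# NOTE: `min_cov` and `pseudocount` are assumed argument names.\n",
    "d6 = Normal(covariance_type='diag', min_cov=1e-3).fit(X)\n",
    "d7 = Categorical(pseudocount=1).fit(X2)\n",
    "\n",
    "d6.covs, d7.probs"
   ]
  },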
  {
   "cell_type": "markdown",
   "id": "62024af5",
   "metadata": {},
   "source": [
    "### Probability and Log Probability\n",
    "\n",
    "All distributions can calculate probabilities and log probabilities using the `probability` and `log_probability` methods, respectively. Each returns one value per example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6a7da803",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-3.9452, -3.2879, -5.3004, -3.6380, -3.9600, -4.9730, -3.2313, -5.4351,\n",
       "        -3.0938, -4.7396, -3.6861, -2.6550, -2.8112, -2.9265, -2.6482, -8.2887,\n",
       "        -3.7147, -2.6614, -2.8981, -9.5658, -6.2381, -3.2002, -5.7639, -6.9646,\n",
       "        -4.4075, -3.8988, -3.0689, -3.2529, -3.6521, -5.3077, -5.5544, -3.2166,\n",
       "        -5.6651, -7.9825, -2.6263, -2.6650, -3.4593, -6.5449, -2.8980, -3.0915,\n",
       "        -4.5713, -3.1680, -4.8918, -3.0811, -4.6555, -3.1913, -3.5364, -3.1703,\n",
       "        -2.5797, -3.4614, -2.5375, -4.8910, -2.9253, -3.9987, -3.0313, -3.2010,\n",
       "        -2.6444, -3.2952, -3.7149, -3.9957, -4.4953, -3.8348, -4.1071, -4.5762,\n",
       "        -2.9732, -2.9576, -3.4012, -3.4736, -3.9769, -3.7505, -4.5513, -4.0950,\n",
       "        -4.5067, -2.7840, -3.3281, -4.1321, -2.9699, -3.8536, -3.9683, -5.8055,\n",
       "        -5.3984, -4.9514, -2.7441, -3.8885, -4.5353, -3.0082, -2.8207, -3.3852,\n",
       "        -3.9225, -3.7536, -6.9391, -3.0570, -5.8579, -3.4830, -2.6783, -5.0286,\n",
       "        -2.9454, -3.4192, -3.8757, -4.4241])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d4.log_probability(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b28cdb26",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ -9.6459,  -2.8243, -16.0460,  -3.9857,  -2.9930, -11.8263,  -2.6130,\n",
       "         -6.0255,  -5.7931,  -3.3603,  -7.4735,  -5.1014,  -3.1393,  -3.6045,\n",
       "         -3.4457, -14.1638,  -2.9771,  -4.5638,  -5.1863,  -5.1922, -10.3993,\n",
       "         -5.5200,  -5.2215,  -7.2889,  -7.4847,  -7.9908,  -2.9989,  -6.8441,\n",
       "         -4.6477,  -4.3911,  -6.8748,  -3.9965, -10.5521, -22.9875,  -3.1194,\n",
       "         -2.8532,  -6.6198,  -8.0589,  -3.4627,  -7.2507,  -5.3280,  -3.2750,\n",
       "         -4.5530,  -6.5848,  -2.8317,  -4.6167,  -9.5592,  -5.2165,  -3.4062,\n",
       "         -3.2597,  -3.9544, -14.5495,  -3.4490,  -3.8333,  -3.5855,  -2.8570,\n",
       "         -3.3047,  -5.5304, -10.1993, -11.8056,  -3.3747,  -8.7955,  -3.4717,\n",
       "        -10.6717,  -3.0119,  -2.9799,  -3.2086,  -3.6065,  -8.6801,  -3.4716,\n",
       "         -4.2680,  -6.6669,  -4.1253,  -3.1685,  -5.0236,  -3.8058,  -3.1228,\n",
       "         -5.6273,  -3.9447,  -5.2440, -14.2746,  -4.6809,  -4.1667,  -3.5050,\n",
       "         -3.8123,  -4.4155,  -4.7357,  -5.1111,  -3.4382,  -6.3055,  -6.9832,\n",
       "         -2.7879,  -5.8146,  -3.9857,  -4.3523,  -5.0716,  -4.9841,  -6.1210,\n",
       "         -3.9729, -10.4107])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d1.log_probability(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9d25b42b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1.9348e-02, 3.7332e-02, 4.9895e-03, 2.6305e-02, 1.9064e-02, 6.9225e-03,\n",
       "        3.9505e-02, 4.3607e-03, 4.5328e-02, 8.7420e-03, 2.5069e-02, 7.0302e-02,\n",
       "        6.0135e-02, 5.3586e-02, 7.0780e-02, 2.5134e-04, 2.4362e-02, 6.9850e-02,\n",
       "        5.5130e-02, 7.0085e-05, 1.9536e-03, 4.0754e-02, 3.1389e-03, 9.4471e-04,\n",
       "        1.2186e-02, 2.0266e-02, 4.6471e-02, 3.8663e-02, 2.5936e-02, 4.9533e-03,\n",
       "        3.8702e-03, 4.0090e-02, 3.4649e-03, 3.4138e-04, 7.2346e-02, 6.9601e-02,\n",
       "        3.1452e-02, 1.4374e-03, 5.5134e-02, 4.5435e-02, 1.0344e-02, 4.2086e-02,\n",
       "        7.5080e-03, 4.5910e-02, 9.5090e-03, 4.1118e-02, 2.9118e-02, 4.1990e-02,\n",
       "        7.5798e-02, 3.1386e-02, 7.9060e-02, 7.5137e-03, 5.3647e-02, 1.8340e-02,\n",
       "        4.8254e-02, 4.0721e-02, 7.1046e-02, 3.7061e-02, 2.4357e-02, 1.8395e-02,\n",
       "        1.1162e-02, 2.1606e-02, 1.6455e-02, 1.0294e-02, 5.1139e-02, 5.1945e-02,\n",
       "        3.3334e-02, 3.1006e-02, 1.8743e-02, 2.3507e-02, 1.0554e-02, 1.6656e-02,\n",
       "        1.1035e-02, 6.1790e-02, 3.5861e-02, 1.6049e-02, 5.1307e-02, 2.1204e-02,\n",
       "        1.8905e-02, 3.0109e-03, 4.5240e-03, 7.0733e-03, 6.4304e-02, 2.0475e-02,\n",
       "        1.0723e-02, 4.9382e-02, 5.9563e-02, 3.3871e-02, 1.9791e-02, 2.3433e-02,\n",
       "        9.6916e-04, 4.7028e-02, 2.8573e-03, 3.0714e-02, 6.8677e-02, 6.5477e-03,\n",
       "        5.2580e-02, 3.2739e-02, 2.0740e-02, 1.1985e-02])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d4.probability(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c0e24986",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([6.4691e-05, 5.9351e-02, 1.0748e-07, 1.8580e-02, 5.0137e-02, 7.3099e-06,\n",
       "        7.3312e-02, 2.4163e-03, 3.0485e-03, 3.4725e-02, 5.6796e-04, 6.0884e-03,\n",
       "        4.3315e-02, 2.7200e-02, 3.1883e-02, 7.0588e-07, 5.0939e-02, 1.0422e-02,\n",
       "        5.5924e-03, 5.5596e-03, 3.0455e-05, 4.0060e-03, 5.3992e-03, 6.8307e-04,\n",
       "        5.6159e-04, 3.3856e-04, 4.9844e-02, 1.0658e-03, 9.5840e-03, 1.2387e-02,\n",
       "        1.0335e-03, 1.8379e-02, 2.6139e-05, 1.0391e-10, 4.4183e-02, 5.7661e-02,\n",
       "        1.3337e-03, 3.1628e-04, 3.1344e-02, 7.0967e-04, 4.8536e-03, 3.7815e-02,\n",
       "        1.0535e-02, 1.3812e-03, 5.8910e-02, 9.8850e-03, 7.0553e-05, 5.4264e-03,\n",
       "        3.3168e-02, 3.8402e-02, 1.9170e-02, 4.7997e-07, 3.1776e-02, 2.1639e-02,\n",
       "        2.7723e-02, 5.7439e-02, 3.6711e-02, 3.9645e-03, 3.7195e-05, 7.4626e-06,\n",
       "        3.4229e-02, 1.5142e-04, 3.1063e-02, 2.3192e-05, 4.9196e-02, 5.0799e-02,\n",
       "        4.0414e-02, 2.7148e-02, 1.6994e-04, 3.1066e-02, 1.4009e-02, 1.2723e-03,\n",
       "        1.6158e-02, 4.2068e-02, 6.5808e-03, 2.2242e-02, 4.4035e-02, 3.5983e-03,\n",
       "        1.9357e-02, 5.2789e-03, 6.3185e-07, 9.2710e-03, 1.5503e-02, 3.0046e-02,\n",
       "        2.2097e-02, 1.2088e-02, 8.7762e-03, 6.0295e-03, 3.2123e-02, 1.8262e-03,\n",
       "        9.2731e-04, 6.1549e-02, 2.9838e-03, 1.8579e-02, 1.2878e-02, 6.2726e-03,\n",
       "        6.8462e-03, 2.1963e-03, 1.8819e-02, 3.0109e-05])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d1.probability(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e5be6cb8",
   "metadata": {},
   "source": [
    "Similar to initialization, the inputs to these methods can be lists, numpy arrays, or torch tensors. For example, passing a numpy array gives the same result as passing the tensor:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5cc0c465",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-3.9452, -3.2879, -5.3004, -3.6380, -3.9600, -4.9730, -3.2313, -5.4351,\n",
       "        -3.0938, -4.7396, -3.6861, -2.6550, -2.8112, -2.9265, -2.6482, -8.2887,\n",
       "        -3.7147, -2.6614, -2.8981, -9.5658, -6.2381, -3.2002, -5.7639, -6.9646,\n",
       "        -4.4075, -3.8988, -3.0689, -3.2529, -3.6521, -5.3077, -5.5544, -3.2166,\n",
       "        -5.6651, -7.9825, -2.6263, -2.6650, -3.4593, -6.5449, -2.8980, -3.0915,\n",
       "        -4.5713, -3.1680, -4.8918, -3.0811, -4.6555, -3.1913, -3.5364, -3.1703,\n",
       "        -2.5797, -3.4614, -2.5375, -4.8910, -2.9253, -3.9987, -3.0313, -3.2010,\n",
       "        -2.6444, -3.2952, -3.7149, -3.9957, -4.4953, -3.8348, -4.1071, -4.5762,\n",
       "        -2.9732, -2.9576, -3.4012, -3.4736, -3.9769, -3.7505, -4.5513, -4.0950,\n",
       "        -4.5067, -2.7840, -3.3281, -4.1321, -2.9699, -3.8536, -3.9683, -5.8055,\n",
       "        -5.3984, -4.9514, -2.7441, -3.8885, -4.5353, -3.0082, -2.8207, -3.3852,\n",
       "        -3.9225, -3.7536, -6.9391, -3.0570, -5.8579, -3.4830, -2.6783, -5.0286,\n",
       "        -2.9454, -3.4192, -3.8757, -4.4241])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d4.log_probability(X.numpy())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b314697",
   "metadata": {},
   "source": [
    "### Summarization\n",
    "\n",
    "Although the primary way to learn parameters from data is to use the `fit` method, the underlying engine for this learning is a pair of operations: `summarize` and `from_summaries`. In `summarize`, the data is condensed into additive sufficient statistics that can be summed across batches. For instance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "49587093",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-3.0192, -0.8312,  2.1886])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d = Normal()\n",
    "d.summarize(X[:5])\n",
    "\n",
    "d._xw_sum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "543d0533",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-1.3155, -6.4282,  9.8475])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d.summarize(X[5:])\n",
    "\n",
    "d._xw_sum"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f0c8be5",
   "metadata": {},
   "source": [
    "These values are the same as the ones we would get by summarizing the entire data set at once."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2e035557",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-1.3155, -6.4282,  9.8475])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d2 = Normal()\n",
    "d2.summarize(X)\n",
    "\n",
    "d2._xw_sum"
   ]
  },
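  {
   "cell_type": "markdown",
   "id": "c41a6f37",
   "metadata": {},
   "source": [
    "To make the role of these statistics concrete, the next cell is a small sketch that recomputes the mean by hand from the stored sums and compares it with the mean of the data. It assumes that `_w_sum` holds the total sample weight (one per example here) and that these private attributes keep their current names, which may change between versions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8029bd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The weighted sum of the examples divided by the total weight gives the mean.\n",
    "# NOTE: `_w_sum` and `_xw_sum` are private attributes and may change.\n",
    "manual_means = d2._xw_sum / d2._w_sum\n",
    "\n",
    "torch.allclose(manual_means, X.mean(dim=0), atol=1e-6)"
   ]
  },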
  {
   "cell_type": "markdown",
   "id": "fe1fe2b3",
   "metadata": {},
   "source": [
    "From these sufficient statistics, usually stored as `_w_sum` and `_xw_sum`, you can exactly recover the parameters you would get by fitting to the entire data set. The `from_summaries` method does this."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d05fe160",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Parameter containing:\n",
       " tensor([-0.0132, -0.0643,  0.0985]),\n",
       " Parameter containing:\n",
       " tensor([-0.0132, -0.0643,  0.0985]))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "d.from_summaries()\n",
    "d2.from_summaries()\n",
    "\n",
    "d.means, d2.means"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f31caa1d",
   "metadata": {},
   "source": [
    "We will explore these ideas further in other tutorials, specifically how they make it easy to implement batching schemes for out-of-core learning and to fit large data sets using limited GPU memory."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
